{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "51466c8d-8ce4-4b3d-be4e-18fdbeda5f53",
   "metadata": {},
   "source": [
    "# How to edit graph state\n",
    "\n",
    "When creating LangGraph agents, it is often nice to add a human-in-the-loop component.\n",
    "This can be helpful when giving them access to tools.\n",
    "Often in these situations you may want to edit the graph state before continuing (for example, to edit what tool is being called, or how it is being called).\n",
    "\n",
    "This can be in several ways, but the primary supported way is to add an \"interrupt\" before a node is executed.\n",
    "This interrupts execution at that node.\n",
    "You can then use `update_state` to update the state, and then resume from that spot to continue.  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cbd446a-808f-4394-be92-d45ab818953c",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First we need to install the packages required"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af4ce0ba-7596-4e5f-8bf8-0b0bd6e62833",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "%pip install --quiet -U langgraph langchain_anthropic"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0abe11f4-62ed-4dc4-8875-3db21e260d1d",
   "metadata": {},
   "source": [
    "Next, we need to set API keys for Anthropic (the LLM we will use)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c903a1cf-2977-4e2d-ad7d-8b3946821d89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ANTHROPIC_API_KEY:  ········\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_env(var: str):\n",
    "    if not os.environ.get(var):\n",
    "        os.environ[var] = getpass.getpass(f\"{var}: \")\n",
    "\n",
    "\n",
    "_set_env(\"ANTHROPIC_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ed46a8-effe-4596-b0e1-a6a29ee16f5c",
   "metadata": {},
   "source": [
    "Optionally, we can set API key for [LangSmith tracing](https://smith.langchain.com/), which will give us best-in-class observability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95e25aec-7c9f-4a63-b143-225d0e9a79c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "_set_env(\"LANGCHAIN_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3333b771",
   "metadata": {},
   "source": [
    "## Build the agent\n",
    "\n",
    "We can now build the agent. We will build a relatively simple ReAct-style agent that does tool calling. We will use Anthropic's models and a fake tool (just for demo purposes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6098e5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the tool\n",
    "from langchain_anthropic import ChatAnthropic\n",
    "from langchain_core.tools import tool\n",
    "from langgraph.graph import MessagesState\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from langgraph.graph import END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "\n",
    "@tool\n",
    "def search(query: str):\n",
    "    \"\"\"Call to surf the web.\"\"\"\n",
    "    # This is a placeholder for the actual implementation\n",
    "    # Don't let the LLM know this though 😊\n",
    "    return [\n",
    "        \"It's sunny in San Francisco, but you better look out if you're a Gemini 😈.\"\n",
    "    ]\n",
    "\n",
    "tools = [search]\n",
    "tool_node = ToolNode(tools)\n",
    "\n",
    "# Set up the model\n",
    "\n",
    "model = ChatAnthropic(model=\"claude-3-5-sonnet-20240620\")\n",
    "model = model.bind_tools(tools)\n",
    "\n",
    "\n",
    "# Define nodes and conditional edges\n",
    "\n",
    "\n",
    "# Define the function that determines whether to continue or not\n",
    "def should_continue(state):\n",
    "    messages = state[\"messages\"]\n",
    "    last_message = messages[-1]\n",
    "    # If there is no function call, then we finish\n",
    "    if not last_message.tool_calls:\n",
    "        return \"end\"\n",
    "    # Otherwise if there is, we continue\n",
    "    else:\n",
    "        return \"continue\"\n",
    "\n",
    "\n",
    "# Define the function that calls the model\n",
    "def call_model(state):\n",
    "    messages = state[\"messages\"]\n",
    "    response = model.invoke(messages)\n",
    "    # We return a list, because this will get added to the existing list\n",
    "    return {\"messages\": [response]}\n",
    "\n",
    "# Define a new graph\n",
    "workflow = StateGraph(MessagesState)\n",
    "\n",
    "# Define the two nodes we will cycle between\n",
    "workflow.add_node(\"agent\", call_model)\n",
    "workflow.add_node(\"action\", tool_node)\n",
    "\n",
    "# Set the entrypoint as `agent`\n",
    "# This means that this node is the first one called\n",
    "workflow.set_entry_point(\"agent\")\n",
    "\n",
    "# We now add a conditional edge\n",
    "workflow.add_conditional_edges(\n",
    "    # First, we define the start node. We use `agent`.\n",
    "    # This means these are the edges taken after the `agent` node is called.\n",
    "    \"agent\",\n",
    "    # Next, we pass in the function that will determine which node is called next.\n",
    "    should_continue,\n",
    "    # Finally we pass in a mapping.\n",
    "    # The keys are strings, and the values are other nodes.\n",
    "    # END is a special node marking that the graph should finish.\n",
    "    # What will happen is we will call `should_continue`, and then the output of that\n",
    "    # will be matched against the keys in this mapping.\n",
    "    # Based on which one it matches, that node will then be called.\n",
    "    {\n",
    "        # If `tools`, then we call the tool node.\n",
    "        \"continue\": \"action\",\n",
    "        # Otherwise we finish.\n",
    "        \"end\": END,\n",
    "    },\n",
    ")\n",
    "\n",
    "# We now add a normal edge from `tools` to `agent`.\n",
    "# This means that after `tools` is called, `agent` node is called next.\n",
    "workflow.add_edge(\"action\", \"agent\")\n",
    "\n",
    "# Set up memory\n",
    "memory = MemorySaver()\n",
    "\n",
    "# Finally, we compile it!\n",
    "# This compiles it into a LangChain Runnable,\n",
    "# meaning you can use it as you would any other runnable\n",
    "\n",
    "# We add in `interrupt_before=[\"action\"]`\n",
    "# This will add a breakpoint before the `action` node is called\n",
    "app = workflow.compile(checkpointer=memory, interrupt_before=[\"action\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a1b56c5-bd61-4192-8bdb-458a1e9f0159",
   "metadata": {},
   "source": [
    "## Interacting with the Agent\n",
    "\n",
    "We can now interact with the agent and see that it stops before calling a tool.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cfd140f0-a5a6-4697-8115-322242f197b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "search for the weather in sf now\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "[{'text': \"Certainly! I'll search for the current weather in San Francisco for you. Let me use the search function to find this information.\", 'type': 'text'}, {'id': 'toolu_019xfFeuGu6taKqA8unr6VxP', 'input': {'query': 'current weather in San Francisco'}, 'name': 'search', 'type': 'tool_use'}]\n",
      "Tool Calls:\n",
      "  search (toolu_019xfFeuGu6taKqA8unr6VxP)\n",
      " Call ID: toolu_019xfFeuGu6taKqA8unr6VxP\n",
      "  Args:\n",
      "    query: current weather in San Francisco\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "thread = {\"configurable\": {\"thread_id\": \"3\"}}\n",
    "inputs = [HumanMessage(content=\"search for the weather in sf now\")]\n",
    "for event in app.stream({\"messages\": inputs}, thread, stream_mode=\"values\"):\n",
    "    event[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78e3f5b9-9700-42b1-863f-c404861f8620",
   "metadata": {},
   "source": [
    "**Edit**\n",
    "\n",
    "We can now update the state accordingly. Let's modify the tool call to have the query `\"current weather in SF\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1aa7b1b9-9322-4815-bc0d-eb083870ac15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'configurable': {'thread_id': '3',\n",
       "  'thread_ts': '1ef355a9-ab59-6102-8002-56ee1c266b09'}}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First, lets get the current state\n",
    "current_state = app.get_state(thread)\n",
    "\n",
    "# Let's now get the last message in the state\n",
    "# This is the one with the tool calls that we want to update\n",
    "last_message = current_state.values['messages'][-1]\n",
    "\n",
    "# Let's now update the args for that tool call\n",
    "last_message.tool_calls[0]['args'] = {'query': 'current weather in SF'}\n",
    "\n",
    "# Let's now call `update_state` to pass in this message in the `messages` key\n",
    "# This will get treated as any other update to the state\n",
    "# It will get passed to the reducer function for the `messages` key\n",
    "# That reducer function will use the ID of the message to update it\n",
    "# It's important that it has the right ID! Otherwise it would get appended\n",
    "# as a new message\n",
    "app.update_state(thread, {\"messages\": last_message})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dcc5457-1ba1-4cba-ac41-da5c67cc67e5",
   "metadata": {},
   "source": [
    "Let's now check the current state of the app to make sure it got updated accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a3fcf2bd-f881-49fe-b20e-ad16e6819bc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'search',\n",
       "  'args': {'query': 'current weather in SF'},\n",
       "  'id': 'toolu_019xfFeuGu6taKqA8unr6VxP'}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "current_state = app.get_state(thread).values['messages'][-1].tool_calls\n",
    "current_state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bca3814-db08-4b0b-8c0c-95b6c5440c81",
   "metadata": {},
   "source": [
    "**Resume**\n",
    "\n",
    "We can now call the agent again with no inputs to continue, ie. run the tool as requested. We can see from the logs that it passes in the update args to the tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "51923913-20f7-4ee1-b9ba-d01f5fb2869b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: search\n",
      "\n",
      "[\"It's sunny in San Francisco, but you better look out if you're a Gemini \\ud83d\\ude08.\"]\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Based on the search results, I can provide you with information about the current weather in San Francisco (SF):\n",
      "\n",
      "The weather in San Francisco is currently sunny. This means it's a clear day with plenty of sunshine.\n",
      "\n",
      "However, I should note that the search result includes an unusual additional comment about astrology, which isn't typically part of a standard weather report. The mention of Geminis is likely not relevant to the actual weather conditions.\n",
      "\n",
      "If you'd like more specific details about the temperature, humidity, or wind conditions in San Francisco, I'd be happy to perform another search with a more focused query. Just let me know if you need any additional weather information!\n"
     ]
    }
   ],
   "source": [
    "for event in app.stream(None, thread, stream_mode=\"values\"):\n",
    "    event[\"messages\"][-1].pretty_print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
